"""Tests for logprobs / prompt_logprobs handling in fastdeploy.entrypoints.llm.LLM.

Covers request validation in LLM._add_request (max_logprobs limits, streaming and
prefix-caching restrictions, logprobs/prompt_logprobs = -1) and prompt-logprob
construction via LLM._build_prompt_logprobs.
"""

import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsTensors


class DummyModelConfig:
    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
        self.max_logprobs = max_logprobs
        self.ori_vocab_size = ori_vocab_size
        self.enable_logprob = enable_logprob


class DummyCacheConfig:
    def __init__(self, enable_prefix_caching=False):
        self.enable_prefix_caching = enable_prefix_caching


class DummyLLMEngineConfig:
    def __init__(self, model_config=None, cache_config=None):
        self.model_config = model_config or DummyModelConfig()
        self.cache_config = cache_config or DummyCacheConfig()


class DummyLLMEngine:
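    """Minimal stand-in for the engine: config, mocked data_processor/tokenizer, and a recorded add_requests."""
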
    def __init__(self, model_config=None, cache_config=None):
        self.cfg = DummyLLMEngineConfig(model_config, cache_config)
        self.data_processor = MagicMock()
        # Mock a tokenizer exposing both sp_model and vocab so vocab-size lookups work
        self.data_processor.tokenizer = MagicMock()
        self.data_processor.tokenizer.sp_model = MagicMock()
        self.data_processor.tokenizer.sp_model.__len__ = MagicMock(return_value=100)
        self.data_processor.tokenizer.vocab = MagicMock()
        self.data_processor.tokenizer.vocab.__len__ = MagicMock(return_value=100)
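        # Decode a list of token ids to a deterministic marker string so tests can assert on decoded_token.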
        self.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
        self.add_requests = MagicMock()


@pytest.fixture
def mock_llm():
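    # Bypass LLM.__init__ so no real engine is started; attach the dummy engine directly.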
    llm = LLM.__new__(LLM)
    llm.llm_engine = DummyLLMEngine()
    return llm


@pytest.fixture
def mock_llm_with_prefix_caching():
    llm = LLM.__new__(LLM)
    llm.llm_engine = DummyLLMEngine(cache_config=DummyCacheConfig(enable_prefix_caching=True))
    return llm


def test_prompt_logprobs_not_supported_with_stream(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=5)
        with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
            mock_llm._add_request(["hi"], sampling, stream=True)


def test_prompt_logprobs_not_supported_with_prefix_caching(mock_llm_with_prefix_caching):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=5)
        with pytest.raises(ValueError, match="prompt_logprobs is not supported with prefix caching enabled"):
            mock_llm_with_prefix_caching._add_request(["hi"], sampling)


def test_num_logprobs_exceeds_max(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 as in the other logprobs tests; 20 exceeds the dummy max_logprobs of 10
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=20)
        with pytest.raises(ValueError, match="Number of logprobs requested"):
            mock_llm._add_request(["hi"], sampling)


def test_max_logprobs_exceeds_vocab_size(mock_llm):
    # max_logprobs exceeds the tokenizer vocabulary size (100)
    mock_llm.llm_engine.cfg.model_config.max_logprobs = 150
    with pytest.raises(ValueError, match="max_logprobs \\(150\\) exceeds vocabulary size \\(100\\)"):
        mock_llm._add_request(["hi"], SamplingParams())


def test_max_logprobs_less_than_minus_one(mock_llm):
    # Test case where max_logprobs < -1
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -2
    with pytest.raises(ValueError, match="max_logprobs \\(-2\\) can't be less than -1"):
        mock_llm._add_request(["hi"], SamplingParams())


def test_logprobs_minus_one_uses_vocab_size(mock_llm):
    # Test that logprobs=-1 uses vocab size
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1  # Allow unlimited
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_num_prompt_logprobs_exceeds_max(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=20)
        with pytest.raises(ValueError, match="Number of logprobs requested"):
            mock_llm._add_request(["hi"], sampling)


def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs=-1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()
        # Get the first argument (tasks) which should be a dict
        call_args = mock_llm.llm_engine.add_requests.call_args
        tasks = call_args[0][0]  # First positional argument
        assert isinstance(tasks, dict)
        assert "prompt" in tasks
        assert "request_id" in tasks


def test_prompt_logprobs_equal_to_minus_one(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support and allow -1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_dynamic_vocab_size_from_sp_model(mock_llm):
    # Test that ori_vocab_size is dynamically obtained from sp_model
    mock_llm.llm_engine.data_processor.tokenizer.sp_model.__len__.return_value = 200
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1

    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm._add_request(["hi"], sampling)
        # Should use the dynamic vocab size (200)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_dynamic_vocab_size_from_vocab_fallback(mock_llm):
    # Test fallback to vocab when sp_model is not available
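    # Deleting the attribute on the MagicMock makes later access raise AttributeError, which triggers the vocab fallback.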
    del mock_llm.llm_engine.data_processor.tokenizer.sp_model
    mock_llm.llm_engine.data_processor.tokenizer.vocab.__len__.return_value = 300
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1

    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm._add_request(["hi"], sampling)
        # Should use the vocab size (300)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_build_prompt_logprobs_basic(mock_llm):
    # Construct 2 prompt tokens, each with 3 candidate logprob values
    token_ids = np.array([[1, 2, 3], [4, 5, 6]])
    logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
    ranks = np.array([1, 2])
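    # Pack into the LogprobsTensors structure consumed by _build_prompt_logprobs.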
    tensors = LogprobsTensors(token_ids, logprobs, ranks)

    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)

    # Check the result format
    assert isinstance(result, list)
    assert len(result) == 3
    for pos_dict in result:
        if pos_dict is not None:
            assert isinstance(pos_dict, dict)
            for logprob_obj in pos_dict.values():
                assert isinstance(logprob_obj, Logprob)
                assert logprob_obj.decoded_token.startswith("TOKEN_")


def test_build_prompt_logprobs_handles_minus_one(mock_llm):
    token_ids = np.array([[7, 8]])
    logprobs = np.array([[-0.9, -1.0]])
    ranks = np.array([1])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)

    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)

    assert isinstance(result, list)
    assert len(result) == 2
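    # With a single row of candidates the result still has two positions; the row maps to index 1.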
    pos_dict = result[1]
    assert 7 in pos_dict
    assert pos_dict[7].decoded_token == "TOKEN_7"
